import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import numpy as np


def get_metadata(config):
    """Return static split sizes, class count and image shape for a task.

    Args:
        config: Object exposing ``config.eval.task`` (task tag) and
            ``config.eval.batch_size`` (int).

    Returns:
        dict with ``num_train``, ``num_test``, ``num_classes``, ``shape``,
        ``num_train_batches`` and ``num_test_batches`` (floor division by the
        batch size, i.e. the last partial batch is dropped). The MNIST-family
        and CIFAR tasks also provide ``num_valid``.

    Raises:
        ValueError: if ``config.eval.task`` is not a known task.
    """
    task = config.eval.task
    metadata = {}
    if task in ['mnist', 'fmnist', 'fmnist_conv']:
        # 80/20 train/validation split of the 60000 training images.
        metadata['num_train'] = 48000  # 60000 * 0.8
        metadata['num_valid'] = 12000
        metadata['num_test'] = 10000
        metadata['num_classes'] = 10
        metadata['shape'] = (28, 28, 1)
    elif task in ['ResNet', 'cifar100']:
        # 80/20 train/validation split of the 50000 training images.
        metadata['num_train'] = 40000
        metadata['num_valid'] = 10000
        metadata['num_test'] = 10000
        # The 'ResNet' task loads cifar10 (10 classes); 'cifar100' has 100.
        metadata['num_classes'] = 10 if task == 'ResNet' else 100
        metadata['shape'] = (32, 32, 3)
    elif task in ['tiny64_resnet56']:
        metadata['num_train'] = 100000
        metadata['num_test'] = 10000
        metadata['num_classes'] = 200
        # NOTE(review): this task loads the 64px Tiny-ImageNet export, yet the
        # declared shape is 32x32 — confirm which is intended. There is also
        # no 'num_valid' entry, so a 'valid' split is unavailable here.
        metadata['shape'] = (32, 32, 3)
    else:
        raise ValueError(f'Invalid data {config.eval.task}')

    metadata['num_train_batches'] = metadata['num_train'] // config.eval.batch_size
    metadata['num_test_batches'] = metadata['num_test'] // config.eval.batch_size

    return metadata

   
def get_nist(name, split='train', num_data=None):
    """Load an MNIST-style TFDS split as one in-memory batch of jax arrays.

    Args:
        name: TFDS dataset name (e.g. ``'mnist'`` or ``'fashion_mnist'``).
        split: TFDS split name.
        num_data: if given, only the first ``num_data`` examples are loaded
            and they are returned as one batch; otherwise the batch size is
            60000, so at most 60000 examples are returned.

    Returns:
        ``(images, labels)`` as jax arrays, with images cast to float32 and
        divided by 255.
    """
    split_spec = split if num_data is None else f'{split}[:{num_data}]'
    dataset = tfds.load(name, split=split_spec, as_supervised=True)

    def _to_unit_range(image, label):
        # Assumes 8-bit pixel values (0..255); dividing maps them to [0, 1].
        return tf.cast(image, tf.float32) / 255.0, label

    dataset = dataset.map(
        _to_unit_range, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.cache().batch(num_data or 60000)

    # Only the first batch is taken; it holds the whole (possibly truncated) split.
    first_batch = next(iter(tfds.as_numpy(dataset)))
    return jax.tree_util.tree_map(jnp.asarray, first_batch)


def get_cifar(name:str, split:str,  data_augmentation:int=0, num_data=None):
    """Load a CIFAR-style TFDS split as one standardised in-memory batch.

    Args:
        name: TFDS dataset name, e.g. ``'cifar10'`` or ``'cifar100'``.
        split: TFDS split name such as ``'train'`` or ``'test'``.
        data_augmentation: when exactly ``1``, apply a 2-pixel zero pad +
            random crop back to the original size and a random horizontal
            flip; any other value disables augmentation.
        num_data: if given, only the first ``num_data`` examples are loaded.

    Returns:
        ``(images, labels)`` as jax arrays. Images are cast to float32,
        divided by 255 and standardised with fixed per-channel mean/std.

    Note:
        The whole split is materialised once as a single batch (batch size
        60000), so any random augmentation is drawn once per call, not once
        per epoch.
    """
    if num_data is None:
        split_spec = split
    else:
        split_spec = f'{split}[:{num_data}]'
    dataset, info = tfds.load(name,
                              split=split_spec,
                              as_supervised=True,
                              with_info=True)
    target_shape = info.features['image'].shape

    def preprocess(image, label):
        # Per-channel normalisation statistics (these look like the usual
        # CIFAR-10 values; they are applied unchanged to every dataset here).
        # They are created before the random ops on purpose: with only a
        # global seed set, TF may derive op-level seeds from graph order.
        mean = tf.constant([[[0.4914, 0.4822, 0.4465]]])
        std = tf.constant([[[0.2470, 0.2435, 0.2616]]])

        if data_augmentation == 1:
            # Pad 2 px per side with zeros, then randomly crop back to the
            # original shape: random translations of up to +/-2 px.
            # NOTE(review): accuracy with this setting was reported lower than
            # expected; the common CIFAR recipe pads 4 px (sometimes REFLECT).
            image = tf.pad(image, [[2, 2], [2, 2], [0, 0]], 'CONSTANT')
            image = tf.image.random_crop(image, target_shape)
            image = tf.image.random_flip_left_right(image)

        # Assumes raw integer pixels in [0, 255] (TFDS image features are
        # uint8 by default — verify); scale to [0, 1] before standardising.
        image = tf.cast(image, tf.float32) / 255.0
        image = (image - mean) / std
        return image, label

    dataset = dataset.map(preprocess).batch(60000)
    first_batch = next(iter(tfds.as_numpy(dataset)))
    return jax.tree_util.tree_map(jnp.asarray, first_batch)

def get_tiny(name:str, split:str,  data_augmentation:int=0, num_data=None,
             root:str='./dataset'):
    """Load a Tiny-ImageNet split from pre-exported ``.npy`` files.

    Reads ``<root>/<dataset>/<split>_images.npy`` and
    ``<root>/<dataset>/<split>_labels.npy``, where ``<dataset>`` is
    ``TinyImageNet200_x32`` or ``TinyImageNet200_x64`` depending on ``name``.

    Args:
        name: task tag selecting the 32px or 64px export.
        split: file prefix, e.g. ``'train'`` or ``'test'``.
        data_augmentation: accepted for signature parity with ``get_cifar``;
            currently unused.
        num_data: keep only the first ``num_data`` examples; ``None`` keeps
            the whole split.
        root: directory containing the exported dataset folders.

    Returns:
        ``(images, labels)`` as jax arrays; images are cast to float32 and
        divided by 255, labels are returned as stored.

    Raises:
        NotImplementedError: if ``name`` is not a known Tiny-ImageNet tag.
    """
    if name in ['tiny32', 'tiny32_resnet50', 'tiny32_resnet20']:
        dataset = 'TinyImageNet200_x32'
    elif name in ['tiny64', 'tiny64_resnet50', 'tiny64_resnet56']:
        dataset = 'TinyImageNet200_x64'
    else:
        raise NotImplementedError('Inappropriate dataset name')

    folder = os.path.join(root, dataset)
    # np.load keeps the raw arrays in host memory; they are converted to jax
    # arrays only once, after scaling.
    images = np.load(os.path.join(folder, f'{split}_images.npy'))
    labels = np.load(os.path.join(folder, f'{split}_labels.npy'))

    # Slicing with None keeps everything, so a missing num_data means
    # "whole split" instead of crashing.
    images = images[:num_data]
    labels = labels[:num_data]

    # float32 array / Python float stays float32.
    images = images.astype(np.float32) / 255.0
    return jnp.asarray(images), jnp.asarray(labels)


def get_data(config, metadata, split='train'):
    """Load a full split for the task named by ``config.eval.task``.

    Args:
        config: Object exposing ``config.eval.task``.
        metadata: dict from ``get_metadata``; its ``num_train`` / ``num_test``
            entry limits how many examples are loaded.
        split: ``'train'`` or ``'test'``.

    Returns:
        Whatever the selected loader returns (an ``(images, labels)`` pair).

    Raises:
        ValueError: for an unsupported split or an unknown task.
    """
    # Guard: only train/test are wired up here.
    if split not in ('train', 'test'):
        raise ValueError('Unknown split type ( valid is not currently available)')
    num_data = metadata['num_train'] if split == 'train' else metadata['num_test']

    task = config.eval.task
    if task in ('fmnist', 'fmnist_conv'):
        return get_nist('fashion_mnist', split=split, num_data=num_data)
    if task == 'mnist':
        return get_nist('mnist', split=split, num_data=num_data)
    # The 'ResNet' task is trained on CIFAR-10.
    if task == 'ResNet':
        return get_cifar(name='cifar10', split=split, num_data=num_data)
    if task == 'cifar100':
        return get_cifar(name='cifar100', split=split, num_data=num_data)
    if task == 'tiny64_resnet56':
        return get_tiny(name=task, split=split, num_data=num_data)
    raise ValueError(f'Unknown dataset {config.eval.task}')


def get_indices(metadata:dict, num_step:int,
                     batch_size:int, key:jnp.array, split:str='train'):
    """Draw a random, non-repeating set of example indices for one epoch.

    All ``num_data`` examples of the split are shuffled and the first
    ``num_step * batch_size`` of them are kept, so every example can be
    drawn and the remainder (if any) is dropped at random.

    Args:
        metadata: dict from ``get_metadata`` holding ``num_train`` /
            ``num_test`` / ``num_valid``.
        num_step: number of batches in the epoch.
        batch_size: examples per batch.
        key: JAX PRNG key.
        split: ``'train'``, ``'test'`` or ``'valid'``.

    Returns:
        ``(indices, key)``: ``indices`` is an int array of shape
        ``[num_step, batch_size]`` with values in ``[0, num_data)``; ``key``
        is a fresh PRNG key, not yet consumed, for the caller to carry forward.

    Raises:
        ValueError: for an unknown split, or if ``num_step * batch_size``
            exceeds the number of examples in the split.
    """
    split_to_field = {'train': 'num_train', 'test': 'num_test', 'valid': 'num_valid'}
    if split not in split_to_field:
        raise ValueError(
            f"Unknown split type '{split}' (expected 'train', 'test' or 'valid')")
    num_data = metadata[split_to_field[split]]

    num_used = int(num_step * batch_size)
    if num_used > num_data:
        # Otherwise indices would point past the end of the data.
        raise ValueError(
            f'num_step * batch_size = {num_used} exceeds the {num_data} '
            f'examples in split {split!r}')

    # perm_key is consumed by the permutation; only the other half of the
    # split is handed back so no key is used twice.
    key, perm_key = jax.random.split(key, 2)
    indices = jax.random.permutation(perm_key, jnp.arange(num_data))[:num_used]
    return indices.reshape((num_step, batch_size)), key